// @HEADER
//*********************************************************************//
//  SiYuan: A numerical PDE solver                                     //
//  Copyright (2022) YUAN Xi                                           //
//  This Software is released under the BSD 2-Clause license detailed  //
//  in the file "LICENSE" in the top-level SiYuan directory            //
//*********************************************************************//
// @HEADER

#ifndef _SchrodingerEquation_impl_hpp_
#define _SchrodingerEquation_impl_hpp_

#include "Teuchos_ParameterList.hpp"
#include "Teuchos_StandardParameterEntryValidators.hpp"
#include "Teuchos_RCP.hpp"
#include "Teuchos_Assert.hpp"
#include "Phalanx_DataLayout_MDALayout.hpp"
#include "Phalanx_FieldManager.hpp"

#include "Panzer_IntegrationRule.hpp"
#include "Panzer_BasisIRLayout.hpp"

// include evaluators here
#include "Panzer_Integrator_BasisTimesScalar.hpp"
#include "Panzer_Integrator_TransientBasisTimesScalar.hpp"
#include "Panzer_Integrator_BasisTimesVector.hpp"
#include "Panzer_Integrator_BasisTimesTensorTimesVector.hpp"
#include "Panzer_Product.hpp"

#include "SchrodingerPotential.hpp"

// ***********************************************************************
template <typename EvalT>
SiYuan::EquationSet_Schrodinger<EvalT>::
EquationSet_Schrodinger(const Teuchos::RCP<Teuchos::ParameterList>& params,
    const int& default_integration_order,
    const panzer::CellData& cell_data,
    const Teuchos::RCP<panzer::GlobalData>& global_data,
    const bool build_transient_support) :
    panzer::EquationSet_DefaultImpl<EvalT>(params,default_integration_order,cell_data,global_data,build_transient_support ) {
  // The "Potential" sublist is mandatory; reject input that lacks it.
  TEUCHOS_TEST_FOR_EXCEPTION(!params->isSublist("Potential"),std::logic_error,"Potential not defined!\n");

  // Check the user-supplied list against the accepted entries and fill in
  // defaults for anything left unspecified.
  {
    Teuchos::ParameterList accepted;
    this->setDefaultValidParameters(accepted);

    accepted.set("Energy unit in eV",1.0,"Energy Unit In Electron Volts");  // default: eV
    accepted.set("Length unit in m",1.e-6,"Length Unit In Meters");         // default: micrometre

    accepted.set("Model ID","","Closure model id associated with this equation set");
    accepted.set("Basis Order",1,"Order of the basis");
    accepted.set("Integration Order",2,"Order of the integration");
    // Entries inside "Potential" are not checked here.
    accepted.sublist("Potential").disableRecursiveValidation();

    params->validateParametersAndSetDefaults(accepted);
  }

  const std::string closureModelId = params->get<std::string>("Model ID");

  // Cache the unit scaling and the potential description; they are forwarded
  // to the evaluators in buildAndRegisterEquationSetEvaluators().
  energy_unit_in_eV = params->get<double>("Energy unit in eV");
  length_unit_in_m  = params->get<double>("Length unit in m");
  pl_potential      = params->sublist("Potential");

  // A single HGrad unknown, together with its gradient and time derivative.
  m_dof_name = "WaveFunction";
  const std::string basisFamily = "HGrad";
  const int basisOrder       = params->get<int>("Basis Order");
  const int integrationOrder = params->get<int>("Integration Order");

  this->addDOF(m_dof_name,basisFamily,basisOrder,integrationOrder);
  this->addDOFGrad(m_dof_name);
  this->addDOFTimeDerivative(m_dof_name);

  this->addClosureModel(closureModelId);
  this->setupDOFs();
}

// ***********************************************************************
template <typename EvalT>
void SiYuan::EquationSet_Schrodinger<EvalT>::
buildAndRegisterEquationSetEvaluators(PHX::FieldManager<panzer::Traits>& fm,
    const panzer::FieldLibrary& /* fl */,
    const Teuchos::ParameterList& ) const
{
  // Registers the three evaluators that make up the wave-function residual:
  //   1. the potential V at the integration points,
  //   2. the transient (mass) term, which EVALUATES RESIDUAL_<dof>,
  //   3. the gradient/potential term, which CONTRIBUTES to RESIDUAL_<dof>.
  // Every evaluator goes through this->registerEvaluator so that the
  // equation-set registrar bookkeeping is applied uniformly.
  using panzer::EvaluatorStyle;

  Teuchos::RCP<panzer::IntegrationRule> ir = this->getIntRuleForDOF(m_dof_name);
  Teuchos::RCP<panzer::BasisIRLayout> basis = this->getBasisIRLayoutForDOF(m_dof_name);

  const std::string resName("RESIDUAL_" + m_dof_name);

  // Parameters shared by the potential and residual evaluators.
  Teuchos::RCP<Teuchos::ParameterList> pl = Teuchos::rcp(new Teuchos::ParameterList);
  pl->set<double>("Energy unit in eV", energy_unit_in_eV);
  pl->set<double>("Length unit in m", length_unit_in_m);

  // Potential
  {
    const std::string potentialType = pl_potential.get<std::string>("Type");
    pl->set<std::string>("Type", potentialType);

    Teuchos::RCP<PHX::Evaluator<panzer::Traits>> ev = Teuchos::rcp( new
      SiYuan::SchrodingerPotential<EvalT, panzer::Traits>(*pl,*ir) );
    this->template registerEvaluator<EvalT>(fm, ev);
  }

  // Residual: transient operator (integral of DXDT_<dof> times the basis).
  // NOTE(review): registered regardless of build_transient_support; confirm the
  // DXDT_ field is always provided when transient support is disabled.
  {
    const std::string valName("DXDT_" + m_dof_name);
    Teuchos::RCP<PHX::Evaluator<panzer::Traits>> top = Teuchos::rcp(new
      Mass_Schrodinger<EvalT, panzer::Traits>(EvaluatorStyle::EVALUATES,
        resName, valName, *basis, *ir) );
    this->template registerEvaluator<EvalT>(fm, top);
  }

  // Residual: gradient and potential terms, added on top of the transient term.
  {
    const std::string gradName("GRAD_" + m_dof_name);
    Teuchos::RCP<PHX::Evaluator<panzer::Traits>> op = Teuchos::rcp( new
      SiYuan::Residual_Schrodinger<EvalT, panzer::Traits>(*pl, EvaluatorStyle::CONTRIBUTES,
        resName, m_dof_name, gradName, *basis, *ir) );
    this->template registerEvaluator<EvalT>(fm, op);
  }
}

// ***********************************************************************

// ***********************************************************************
/////////////////////////////////////////////////////////////////////////////
//
//  Main Constructor: Residual_Schrodinger
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
SiYuan::Residual_Schrodinger<EvalT, Traits>::
Residual_Schrodinger(
    const Teuchos::ParameterList& p,
    const panzer::EvaluatorStyle&   evalStyle,
    const std::string&              resName,
    const std::string&              valName,
    const std::string&              gradName,
    const panzer::BasisIRLayout&    basis,
    const panzer::IntegrationRule&  ir )
    :
    evalStyle_(evalStyle),
    basisName_(basis.name())
{
    // Unit scaling supplied by the equation set ("Energy unit in eV", "Length unit in m").
    const double energy_unit_in_eV = p.get<double>("Energy unit in eV");
    const double length_unit_in_m = p.get<double>("Length unit in m");

    // hbar^2/(2 m0) expressed in [energy unit] * [length unit]^2, so that the
    // kinetic energy carries the requested energy unit.
    constexpr double hbar   = 1.0546e-34;  // reduced Planck constant [J s]
    constexpr double evPerJ = 6.2415e18;   // eV per Joule (eV/J)
    constexpr double emass  = 9.1094e-31;  // electron rest mass [kg]
    hbar2_over_2m0 = 0.5*hbar*hbar*evPerJ /
        (emass*energy_unit_in_eV*length_unit_in_m*length_unit_in_m);
    // NOTE(review): the physical coefficient computed above is unconditionally
    // overwritten here, so the unit parameters have no effect on the kinetic term.
    // This looks like a debugging leftover; confirm whether it should be removed.
    hbar2_over_2m0 = 1.0;

    // Inputs at the integration points: psi, grad(psi) and the potential V.
    psi_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>(valName, ir.dl_scalar);
    this->addDependentField(psi_);
    psiGrad_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP, panzer::Dim>(gradName, ir.dl_vector);
    this->addDependentField(psiGrad_);
    const std::string potentialName("Potential");
    V_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>(potentialName, ir.dl_scalar);
    this->addDependentField(V_);

    // Output: the residual at the basis functions, either contributed to or evaluated.
    const bool contributes = (evalStyle == panzer::EvaluatorStyle::CONTRIBUTES);
    field_ = PHX::MDField<ScalarT, panzer::Cell, panzer::BASIS>(resName, basis.functional);
    if (contributes)
      this->addContributedField(field_);
    else
      this->addEvaluatedField(field_);

    // Set the name of this object.
    std::string n("Integrator_WaveFunction (");
    n += contributes ? "CONTRIBUTES" : "EVALUATES";
    n += "):  " + field_.fieldTag().name();
    this->setName(n);
} // end of Main Constructor

/////////////////////////////////////////////////////////////////////////////
//
//  postRegistrationSetup()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Residual_Schrodinger<EvalT, Traits>::postRegistrationSetup(
    typename Traits::SetupData sd, PHX::FieldManager<Traits>& /* fm */)
{
    // Resolve the position of our basis in the workset basis list once, using the
    // first workset (assumes all worksets share the same basis ordering).
    const auto& firstWorkset = (*sd.worksets_)[0];
    basisIndex_ = panzer::getBasisIndex(basisName_, firstWorkset, this->wda);
}

/////////////////////////////////////////////////////////////////////////////
//
//  evaluateFields()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Residual_Schrodinger<EvalT, Traits>::
evaluateFields( typename Traits::EvalData workset )
{
    // Adds, for every cell and basis function w:
    //   sum_qp [ hbar2_over_2m0 * grad(psi) . grad(w) - psi * V * w ] (weighted)
    //
    // Skipped entirely when workset.beta is zero. NOTE(review): presumably beta is
    // the coefficient of the non-transient terms in Panzer's alpha/beta assembly.
    // Because nothing is written in that case, the early return is only safe in the
    // CONTRIBUTES style, which is how this evaluator is registered.
    if( workset.beta==0.0 ) return;

    // Grab the basis information.
    const panzer::BasisValues2<double>& bv = *this->wda(workset).bases[basisIndex_];

    PHX::MDField<double, panzer::Cell, panzer::BASIS, panzer::IP, panzer::Dim>
          weightedGradBasis = bv.weighted_grad_basis;
    PHX::MDField<double, panzer::Cell, panzer::BASIS, panzer::IP>
          weightedBasis = bv.weighted_basis_scalar;

    // weightedGradBasis is (Cell, BASIS, IP, Dim): the spatial dimension is
    // extent 3, not extent 0 (which is the number of cells in the workset).
    const int numBases = weightedGradBasis.extent(1);
    const int numQP    = weightedGradBasis.extent(2);
    const int numDims  = weightedGradBasis.extent(3);

    for( int cell=0; cell<workset.num_cells; cell++ )
    {
        for (int basis(0); basis < numBases; ++basis)
        {
          if (evalStyle_ == panzer::EvaluatorStyle::EVALUATES)
            field_(cell, basis) = 0.0;
          for (int qp(0); qp < numQP; ++qp) {
            // Kinetic (diffusion) term.
            for (int dim(0); dim<numDims; ++dim ) {
              field_(cell, basis) += hbar2_over_2m0 * psiGrad_(cell, qp, dim) *
                weightedGradBasis(cell, basis, qp, dim);
            }
            // Potential term.
            field_(cell, basis) -= psi_(cell, qp)* V_(cell, qp)* weightedBasis(cell, basis, qp);
          }
        } // end loop over the basis functions
    }

} // end of evaluateFields()

// ***********************************************************************
/////////////////////////////////////////////////////////////////////////////
//
//  Main Constructor: Mass_Schrodinger
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
SiYuan::Mass_Schrodinger<EvalT, Traits>::
Mass_Schrodinger( const panzer::EvaluatorStyle&   evalStyle,
        const std::string&              resName,
        const std::string&              valName,
        const panzer::BasisIRLayout&    basis,
        const panzer::IntegrationRule&  ir )
    :
    evalStyle_(evalStyle),
    basisName_(basis.name())
{
    const bool contributes = (evalStyle == panzer::EvaluatorStyle::CONTRIBUTES);

    // Input: the field named valName, a scalar at the integration points.
    DXDT_psi_ = PHX::MDField<const ScalarT, panzer::Cell, panzer::IP>(valName, ir.dl_scalar);
    this->addDependentField(DXDT_psi_);

    // Output: one residual entry per basis function, either accumulated onto an
    // existing field (CONTRIBUTES) or owned and overwritten (EVALUATES).
    field_ = PHX::MDField<ScalarT, panzer::Cell, panzer::BASIS>(resName, basis.functional);
    if (contributes)
      this->addContributedField(field_);
    else
      this->addEvaluatedField(field_);

    // Human-readable evaluator name, e.g. "Mass_WaveFunction (EVALUATES):  <field>".
    const std::string styleTag = contributes ? "CONTRIBUTES" : "EVALUATES";
    this->setName("Mass_WaveFunction (" + styleTag + "):  " + field_.fieldTag().name());
} // end of Main Constructor

/////////////////////////////////////////////////////////////////////////////
//
//  postRegistrationSetup()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Mass_Schrodinger<EvalT, Traits>::postRegistrationSetup(
    typename Traits::SetupData sd, PHX::FieldManager<Traits>& /* fm */)
{
    // Resolve the position of our basis in the workset basis list once, using the
    // first workset (assumes all worksets share the same basis ordering).
    const auto& firstWorkset = (*sd.worksets_)[0];
    basisIndex_ = panzer::getBasisIndex(basisName_, firstWorkset, this->wda);
}

/////////////////////////////////////////////////////////////////////////////
//
//  evaluateFields()
//
/////////////////////////////////////////////////////////////////////////////
template<typename EvalT, typename Traits>
void
SiYuan::Mass_Schrodinger<EvalT, Traits>::
evaluateFields( typename Traits::EvalData workset )
{
    // Computes, for every cell and basis function w:
    //   sum_qp DXDT_psi * w (weighted)
    //
    // There is deliberately no early return on workset.alpha: in the EVALUATES
    // style this evaluator is responsible for initialising field_, so it must
    // always write it.

    // Grab the basis information.
    const panzer::BasisValues2<double>& bv = *this->wda(workset).bases[basisIndex_];

    PHX::MDField<double, panzer::Cell, panzer::BASIS, panzer::IP>
          weightedBasis = bv.weighted_basis_scalar;

    const int numBases = weightedBasis.extent(1);
    const int numQP    = weightedBasis.extent(2);
    for( int cell=0; cell<workset.num_cells; cell++ )
    {
        for (int basis(0); basis < numBases; ++basis)
        {
          if (evalStyle_ == panzer::EvaluatorStyle::EVALUATES)
            field_(cell, basis) = 0.0;
          for (int qp(0); qp < numQP; ++qp) {
            field_(cell, basis) += DXDT_psi_(cell, qp)* weightedBasis(cell, basis, qp);
          }
        } // end loop over the basis functions
    }

} // end of evaluateFields()

#endif 
